// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <cmath>
#include <exception>
#include <iterator>
#include <limits>
#include <ostream>
#include <stdexcept>
#include <string>
#include <variant>
#include <vector>

#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/dtype.h"
#include "mlx/stream.h"

namespace mlx::core {

// Holds nothing (std::monostate), a specific Stream, or a Device.
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
// Resolves a StreamOrDevice to a concrete Stream. Defined out of line; what is
// returned for std::monostate or a bare Device is not visible here (presumably
// the relevant default stream — confirm against the definition).
Stream to_stream(StreamOrDevice s);
// Same as above, but with an explicit fallback device `default_`; presumably
// used when `s` does not name a stream or device — TODO confirm.
Stream to_stream(StreamOrDevice s, Device default_);

/**
 * Scoped override of the default device and stream.
 *
 * On construction, the default stream of the current default device is saved.
 * Then `s` is resolved with `to_stream`, and its device and stream are installed
 * as the new defaults. On destruction, the saved device and stream are
 * reinstalled.
 *
 * The guard is non-copyable and non-movable. Every copy would reinstall the
 * saved defaults again when it is destroyed, silently clobbering any defaults
 * that were set after the original guard went out of scope.
 *
 * @throws std::runtime_error if `s` holds std::monostate.
 */
struct StreamContext {
 public:
  StreamContext(StreamOrDevice s) : _stream(default_stream(default_device())) {
    if (std::holds_alternative<std::monostate>(s)) {
      throw std::runtime_error(
          "[StreamContext] Invalid argument, please specify a stream or device.");
    }
    auto _s = to_stream(s);
    set_default_device(_s.device);
    set_default_stream(_s);
  }

  // A scope guard must have exactly one owner of the "restore" action.
  StreamContext(const StreamContext&) = delete;
  StreamContext& operator=(const StreamContext&) = delete;
  StreamContext(StreamContext&&) = delete;
  StreamContext& operator=(StreamContext&&) = delete;

  ~StreamContext() {
    // Reinstall the defaults captured in the constructor.
    set_default_device(_stream.device);
    set_default_stream(_stream);
  }

 private:
  // Default stream (of the default device) at the time this guard was created.
  Stream _stream;
};

/**
 * Writes single scalar values to an output stream.
 *
 * There is one `print` overload per supported scalar type. The definitions are
 * not in this header. It is presumably used when printing arrays; confirm
 * against callers.
 */
struct PrintFormatter {
  // Boolean values.
  inline void print(std::ostream& os, bool val);

  // Signed integers.
  inline void print(std::ostream& os, int16_t val);
  inline void print(std::ostream& os, int32_t val);
  inline void print(std::ostream& os, int64_t val);

  // Unsigned integers.
  inline void print(std::ostream& os, uint16_t val);
  inline void print(std::ostream& os, uint32_t val);
  inline void print(std::ostream& os, uint64_t val);

  // Floating-point values, from narrowest to widest.
  inline void print(std::ostream& os, float16_t val);
  inline void print(std::ostream& os, bfloat16_t val);
  inline void print(std::ostream& os, float val);
  inline void print(std::ostream& os, double val);

  // Complex values.
  inline void print(std::ostream& os, complex64_t val);

  // Presumably selects capitalized boolean output (e.g. "True" vs "true");
  // confirm against the bool overload's definition.
  bool capitalize_bool = false;
};

// Returns a reference to the shared ("global") PrintFormatter instance.
// Defined out of line; presumably the formatter consulted when printing arrays,
// so changes to it (e.g. capitalize_bool) affect later output — confirm.
PrintFormatter& get_global_formatter();

/** Print the exception and then abort. */
void abort_with_exception(const std::exception& error);

/**
 * Holds information about floating-point types.
 *
 * The fields are filled in by the out-of-line constructor from `dtype`.
 * Presumably `min`/`max` are the finite extremes and `eps` is the machine
 * epsilon of that type; confirm against the constructor's definition.
 * Values are stored as double regardless of the described dtype.
 */
struct finfo {
  explicit finfo(Dtype dtype);
  Dtype dtype; // Presumably the dtype passed to the constructor.
  double min;
  double max;
  double eps;
};

/**
 * Holds information about integral types.
 *
 * The fields are filled in by the out-of-line constructor from `dtype`.
 */
struct iinfo {
  explicit iinfo(Dtype dtype);
  Dtype dtype; // Presumably the dtype passed to the constructor.
  // The asymmetric field types presumably let `min` hold negative minimums of
  // signed types and `max` hold values above INT64_MAX (e.g. for uint64) —
  // confirm against the constructor.
  int64_t min;
  uint64_t max;
};

/**
 * Returns the dtype obtained by promoting the arrays' dtypes against one
 * another via `promote_types`.
 */
inline Dtype result_type(const array& a, const array& b) {
  const Dtype lhs = a.dtype();
  const Dtype rhs = b.dtype();
  return promote_types(lhs, rhs);
}
inline Dtype result_type(const array& a, const array& b, const array& c) {
  // Fold left: (a ⊕ b) ⊕ c.
  const Dtype first_two = promote_types(a.dtype(), b.dtype());
  return promote_types(first_two, c.dtype());
}
Dtype result_type(const std::vector<array>& arrays);

// Returns the shape that results from broadcasting `s1` against `s2`.
// Defined out of line; presumably NumPy-style broadcasting that throws on
// incompatible shapes — confirm against the definition.
Shape broadcast_shapes(const Shape& s1, const Shape& s2);

/**
 * Returns the axis normalized to be in the range [0, ndim).
 *
 * @param axis       Axis to normalize. Presumably negative values count from
 *                   the end (Python-style); confirm in the definition.
 * @param ndim       Number of dimensions the axis refers to.
 * @param msg_prefix Text presumably prepended to the error reported for an
 *                   out-of-range axis; confirm in the definition.
 */
int normalize_axis_index(
    int axis,
    int ndim,
    const std::string& msg_prefix = "");

// Stream-insertion operators for core MLX types; bodies are defined elsewhere.
std::ostream& operator<<(std::ostream& os, const Device& d);
std::ostream& operator<<(std::ostream& os, const Stream& s);
std::ostream& operator<<(std::ostream& os, const Dtype& d);
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
// NOTE(review): `array` is taken by value, unlike the overloads above.
// This is presumably deliberate (the definition may need a mutable copy); confirm.
std::ostream& operator<<(std::ostream& os, array a);
/**
 * Prints a complex value as "<real>+<imag>j" or "<real>-<imag>j"
 * (e.g. "1+2j", "1-2j").
 *
 * The separator is chosen from the sign bit of the imaginary part, and the
 * magnitude is printed after it. Two inputs were mishandled by the previous
 * `imag >= 0` test:
 *  - NaN compares false against 0, which produced "1nanj". NaN is now always
 *    printed as "+nan", regardless of its sign bit.
 *  - -0.0 compares true against 0, which produced "1+-0j". It now prints
 *    "1-0j".
 */
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
  const auto im = v.imag();
  const bool negative = std::signbit(im) && !std::isnan(im);
  return os << v.real() << (negative ? "-" : "+") << std::abs(im) << "j";
}
/** Prints a half-precision value by first widening it to float. */
inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
  const float widened = static_cast<float>(v);
  os << widened;
  return os;
}
/** Prints a bfloat16 value by first widening it to float. */
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
  const float widened = static_cast<float>(v);
  os << widened;
  return os;
}

/**
 * Prints a vector-like container as a parenthesized, comma-separated list with
 * no spaces, e.g. "(1,2,3)". An empty container prints "()".
 */
template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
inline std::ostream& operator<<(std::ostream& os, const Vec& v) {
  os << "(";
  // Emit the separator before every element except the first.
  const char* separator = "";
  for (const auto& element : v) {
    os << separator << element;
    separator = ",";
  }
  return os << ")";
}

/**
 * Returns true if `n` is a positive integral power of two (1, 2, 4, ...).
 *
 * Zero and negative values return false. `n > 0` is tested first, so `n - 1`
 * is never evaluated for INT_MIN. There it would be signed overflow, which is
 * undefined behavior.
 */
inline bool is_power_of_2(int n) {
  return n > 0 && (n & (n - 1)) == 0;
}

/**
 * Returns the smallest power of two that is >= `n`.
 *
 * Non-positive inputs return 0. This matches the previous result for n == 0;
 * negative inputs used to hit undefined behavior (NaN converted to int).
 *
 * @throws std::overflow_error if the result does not fit in an int,
 *         i.e. n > 2^30 for a 32-bit int.
 */
inline int next_power_of_2(int n) {
  if (n <= 0) {
    return 0;
  }
  constexpr int kLargestPow2 = std::numeric_limits<int>::max() / 2 + 1;
  if (n > kLargestPow2) {
    throw std::overflow_error(
        "[next_power_of_2] The next power of two does not fit in an int.");
  }
  // Copy the highest set bit of (n - 1) into every lower position, then add 1.
  // This stays in exact integer arithmetic, unlike pow/log2.
  auto v = static_cast<unsigned int>(n - 1);
  for (int shift = 1; shift < std::numeric_limits<unsigned int>::digits;
       shift <<= 1) {
    v |= v >> shift;
  }
  return static_cast<int>(v + 1);
}

namespace env {

// Returns an integer setting for `name`, or `default_value` if unavailable.
// Defined out of line; presumably reads the environment variable `name` —
// confirm against the definition.
int get_var(const char* name, int default_value);

// Every accessor below caches its result in a function-local static on its
// first call, and later calls return that cached value. For accessors that take
// a `default_value`, only the argument passed on the FIRST call has any effect.

inline int bfs_max_width() {
  static const int cached = get_var("MLX_BFS_MAX_WIDTH", 20);
  return cached;
}

inline int max_ops_per_buffer(int default_value) {
  static const int cached = get_var("MLX_MAX_OPS_PER_BUFFER", default_value);
  return cached;
}

inline int max_mb_per_buffer(int default_value) {
  static const int cached = get_var("MLX_MAX_MB_PER_BUFFER", default_value);
  return cached;
}

inline bool metal_fast_synch() {
  // Any nonzero value enables the flag.
  static const bool cached = get_var("MLX_METAL_FAST_SYNCH", 0) != 0;
  return cached;
}

inline bool enable_tf32() {
  // Any nonzero value enables the flag; enabled by default.
  static const bool cached = get_var("MLX_ENABLE_TF32", 1) != 0;
  return cached;
}

inline int nccl_timeout(int default_value) {
  static const int cached = get_var("MLX_NCCL_TIMEOUT", default_value);
  return cached;
}

} // namespace env

} // namespace mlx::core
